﻿import torch
import numpy as np
import scipy.linalg
from . import metric_utils
from sklearn.metrics.pairwise import cosine_similarity
from scipy import spatial
#----------------------------------------------------------------------------

def compute_simi(opts, max_real):
    """Return the cosine similarity between masked and unmasked real-feature means.

    Real-data feature statistics are collected twice via
    ``metric_utils.compute_feature_stats_for_discriminator``: once with
    ``mask=True`` and once with ``mask=False``. The mean vectors of the two
    collections are then compared with cosine similarity.

    Args:
        opts: Metric options object. Only ``opts.rank`` is read directly here;
            the whole object is forwarded to ``metric_utils``.
        max_real: Maximum number of real items to use, forwarded as
            ``max_items``.

    Returns:
        On rank 0, the cosine similarity as a Python ``float``. On any other
        rank, ``float('nan')``.
    """
    # Stat collection runs on every rank. NOTE(review): presumably this is a
    # collective operation in multi-process runs, so it must not be skipped;
    # only the final reduction below is rank-0-only.
    mu_real_masked, _sigma_real_masked = metric_utils.compute_feature_stats_for_discriminator(
        opts=opts, rel_lo=0, rel_hi=0, mask=True, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_real, _sigma_real = metric_utils.compute_feature_stats_for_discriminator(
        opts=opts, rel_lo=0, rel_hi=0, mask=False, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    # Bail out before the reduction: non-zero ranks report NaN anyway, so
    # computing (and possibly failing on) the similarity there is wasted work.
    if opts.rank != 0:
        return float('nan')

    # scipy returns the cosine *distance* (1 - cos θ); convert back to similarity.
    simi = 1 - spatial.distance.cosine(mu_real, mu_real_masked)
    return float(simi)

#----------------------------------------------------------------------------
